package it.uniroma2.tk;

import it.uniroma2.util.math.ArrayMath;
import it.uniroma2.util.tree.Tree;
import it.uniroma2.util.vector.SemanticVectorProvider;

import java.io.File;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Vector;

import edu.berkeley.compbio.jlibsvm.kernel.KernelFunction;

/**
 * @author Fabio Massimo Zanzotto
 * This class implements the Smoothed Tree Kernel computation between Tree objects as introduced in 
 * (Mehdad, Y.; Moschitti, A. & Zanzotto, F. M. Syntactic/Semantic Structures for Textual Entailment Recognition, NAACL, 2010)
 * 
 */
public class SmoothedTreeKernel implements KernelFunction<Tree> {

	public static double lambda = 1;
	public static boolean lexicalized = false;
	private static int nodeCount = 0;
	private static HashMap<String,Double> deltaMatrix; 
	private static HashMap<Tree,Integer> nodeIndices; 
	
	private static SemanticVectorProvider ds = null;
	
	public static void initializeSemanticVectors(int vectorSize,File distributionalDictionary) throws Exception {
		ds = new SemanticVectorProvider(vectorSize, distributionalDictionary);
	}
	
	public static double value(Tree a, Tree b) throws Exception {
		deltaMatrix = new HashMap<String,Double>();
		nodeIndices = new HashMap<Tree,Integer>();
		nodeCount = 0;
		double sum = 0;
		for (Tree aa : allNodes(a))
			for (Tree bb : allNodes(b))
				sum += delta(aa,bb);
		return sum;
	}

	private static double delta(Tree a,Tree b) throws Exception {
		double k = 0;
		if (!nodeIndices.containsKey(a)) {
			nodeIndices.put(a,nodeCount);
			nodeCount++;
		}
		if (!nodeIndices.containsKey(b)) {
			nodeIndices.put(b,nodeCount);
			nodeCount++;
		}
		if (deltaMatrix.containsKey(nodeIndices.get(a) + ":" +nodeIndices.get(b))) {
			return deltaMatrix.get(nodeIndices.get(a) + ":" +nodeIndices.get(b));
		}

		if (a.getChildren().size() == b.getChildren().size()) {
			if (a.getChildren().size() == 1 && a.getChildren().get(0).isTerminal() && b.getChildren().get(0).isTerminal()) {
				if (lexicalized) 
					//&& a.equals(b))
					k = ArrayMath.cosine(ds.getVector(a.getRootLabel()),ds.getVector(b.getRootLabel()));
			} else {
				
				//if (productionCompare(a, b)) {
				k = ArrayMath.cosine(ds.getVector(a.getRootLabel()),ds.getVector(b.getRootLabel()));
				for (int i=0; i<a.getChildren().size(); i++) {
					k = k*(1+lambda*delta(a.getChildren().get(i),b.getChildren().get(i)));
					}
				//} 
			}
		}
		deltaMatrix.put(nodeIndices.get(a) + ":" +nodeIndices.get(b),k);
		return k;
	}
	
	
	private static Vector<Tree> allNodes(Tree node) {
		Vector<Tree> all = new Vector<Tree>();
		all.add(node);
		for (Tree child : node.getChildren())
			all.addAll(allNodes(child));
		return all;
	}

	public double evaluate(Tree arg0, Tree arg1) {
		try {
			return value(arg0, arg1);
		} catch (Exception e) {
			e.printStackTrace();
			return 0;
		}
	}
	
}
